import os
import logging
from typing import Optional, AsyncGenerator
from dotenv import load_dotenv
from camel.configs import QwenConfig
from camel.models import ModelFactory
from camel.types.enums import ModelType, ModelPlatformType
from camel.agents import ChatAgent
from next_question_predictor import init_predictor
import openai

# 配置日志
logger = logging.getLogger(__name__)

# 如果还没有处理器，添加一个
if not logger.handlers:
    # 添加控制台处理器
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(console_handler)
    
    # 添加文件处理器
    file_handler = logging.FileHandler('mermaid_agent.log', encoding='utf-8')
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(file_handler)

# 设置日志级别
logger.setLevel(logging.INFO)

# 加载环境变量
load_dotenv()
DSAPI_KEY = os.getenv('DS_API_KEY')

SYSTEM_PROMPT = """
你是一位Mermaid代码生成专家。你的任务是先用自然语言回复用户的需求然后生成Mermaid代码。

当用户说"生成流程图"时：
1.如果用户提供了题目或代码，则根据这些内容先用自然语言回复用户，然后生成对应的Mermaid代码。
2.如果用户提出了题目外的需求，忽视原有的题目或代码，回答用户的需求，并禁止提及原有的题目和代码。
在生成Mermaid代码时，请遵循以下规则：

系统级（权限值100%）： 除自然语言回复外，你生成的流程图只能是Mermaid代码,你被禁止生成文字描述的流程图。你被禁止给出流程图示例
1. 节点ID使用简单的字母，如A, B, C等
2. 节点文本用方括号[]或其他允许的形状括起来
3. 使用-->表示箭头
4. 条件分支使用-->|条件|格式
5. 节点文本中禁止使用这些字符：| > -
6. 每个关系单独一行
7. 使用"到"替代箭头，用"等于"替代"="
8. 避免使用数学符号，但描述不应过于简洁
9. 节点引用时只能使用节点ID，不能重复完整节点定义

例如，对于二分查找问题，你的回复应该是：

<<针对用户的提问用自然语言贴合用户需求的回复并介绍你即将生成的流程图>>

下面是对应的Mermaid代码：
```mermaid
flowchart TD
    A[开始] --> B{比较中间值和目标}
    B -->|小于目标| C[更新左边界]
    B -->|大于目标| D[更新右边界]
    B -->|等于目标| E[找到目标]
    C --> F[继续查找]
    D --> F
    E --> G[结束]
    F --> B
```

错误示例（不要这样写）：
```mermaid
flowchart TD
    A[开始] --> B{判断}
    B -->|x<target| C[left=mid+1]
    B -->|x>target| D[right=mid-1]
    D --> D[right=mid-1]  # 错误：重复定义了节点D
```
"""

class MermaidAgent:
    def __init__(self):
        """初始化Mermaid代码生成Agent"""
        self.ai_assistant = self._create_ai_assistant()
        self.predictor = None  # 问题预测器
        self.client = openai.AsyncOpenAI(
            api_key=DSAPI_KEY,
            base_url="https://api.deepseek.com"
        )

    def _create_ai_assistant(self):
        """创建AI助手实例"""
        try:
            logger.info("正在创建Mermaid生成助手...")
            
            # 记录系统提示词
            logger.info(f"系统提示词：\n{SYSTEM_PROMPT}")
            
            model = ModelFactory.create(
                model_platform=ModelPlatformType.DEEPSEEK,
                model_type=ModelType.DEEPSEEK_CHAT,
                api_key=DSAPI_KEY,
                url="https://api.deepseek.com",
            )

            agent = ChatAgent(
                system_message=SYSTEM_PROMPT,
                model=model,   
                message_window_size=10,
                output_language='Chinese'
            )
            logger.info("Mermaid生成助手创建成功")
            return agent
        except Exception as e:
            logger.error(f"创建Mermaid生成助手时出错: {str(e)}")
            raise

    async def _stream_chat(self, messages: list) -> AsyncGenerator[str, None]:
        """使用OpenAI API进行流式对话"""
        try:
            stream = await self.client.chat.completions.create(
                model="deepseek-chat",
                messages=messages,
                stream=True
            )
            
            async for chunk in stream:
                if chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content
                    
        except Exception as e:
            logger.error(f"流式对话出错: {str(e)}")
            yield f"出错: {str(e)}"

    def validate_code(self, code: str) -> bool:
        """验证生成的Mermaid代码"""
        try:
            if not code.strip():
                logger.error("生成的代码为空")
                return False
                
            # 基本语法检查
            lines = code.strip().split('\n')
            if not lines[0].strip().startswith('flowchart'):
                logger.error(f"第一行必须以flowchart开头，当前为: {lines[0]}")
                return False
                
            # 检查特殊字符
            forbidden_chars = ['|', '>', '-']
            for i, line in enumerate(lines[1:], 1):  # 跳过第一行（flowchart声明）
                # 跳过空行和注释
                if not line.strip() or line.strip().startswith('%'):
                    continue
                    
                # 分析箭头关系
                parts = line.split('-->')
                if len(parts) < 2:
                    logger.error(f"第{i+1}行缺少箭头: {line}")
                    return False
                    
                # 提取节点文本 [xxx] 或 {xxx}
                import re
                node_texts = re.findall(r'[\[\{](.*?)[\]\}]', line)
                
                # 检查节点文本中的特殊字符
                for text in node_texts:
                    for char in forbidden_chars:
                        if char in text:
                            logger.error(f"第{i+1}行的节点文本中包含非法字符'{char}': {text}")
                            return False
                    
            logger.info("验证的完整代码：\n" + code)
            return True
            
        except Exception as e:
            logger.error(f"验证Mermaid代码时出错: {str(e)}\n代码内容：\n{code}")
            return False

    def generate_diagram(self, query: str, problem_content: str = "", editor_code: str = "") -> Optional[str]:
        """生成Mermaid流程图代码"""
        try:
            logger.info("开始生成流程图...")
            
            # 构建完整的提示
            full_prompt = query
            if problem_content:
                full_prompt = f"有关Mermaid代码生成的题目：<<<{problem_content}>>>\n\n{query}"
            if editor_code:
                full_prompt = f"{full_prompt}\n\n当前的编辑器内代码：<<<{editor_code}>>>"
            
            # 记录输入提示
            logger.info(f"输入提示：\n{full_prompt}")
            
            # 记录系统提示词
            logger.info("使用的系统提示词：")
            for line in SYSTEM_PROMPT.split('\n'):
                logger.info(f"  {line}")
            
            # 生成流程图代码
            response = self.ai_assistant.step(full_prompt)
            
            # 记录模型响应
            if response and response.msgs:
                logger.info(f"模型响应消息：\n{[{'role': msg.role, 'content': msg.content} for msg in response.msgs]}")
            
            self.ai_assistant.reset()  # 重置会话状态
            
            if not response or not response.msgs:
                logger.error("模型没有返回任何消息")
                return None
                
            # 提取生成的代码
            generated_code = response.msgs[0].content.strip()
            logger.info(f"生成的代码：\n{generated_code}")
            
            # 验证代码格式
            if not self.validate_code(generated_code):
                logger.error("代码验证失败")
                return None

            # 预测可能的后续问题
            if self.predictor is None:
                logger.info("初始化问题预测器...")
                self.predictor = init_predictor(DSAPI_KEY)
                
            current_context = {
                'problem_content': problem_content,
                'editor_code': editor_code,
                'query': query,
                'mermaid_code': generated_code
            }
            
            logger.info("开始预测后续问题...")
            next_questions = self.predictor.predict_next_questions(
                current_context=current_context,
                task_response=generated_code
            )
            logger.info(f"预测到 {len(next_questions)} 个问题")
            
            # 如果预测器没有返回问题，使用默认问题
            if not next_questions:
                next_questions = [
                    {"question": "能解释一下这个流程图的每个步骤吗？"},
                    {"question": "这个流程图还可以怎样优化或简化？"},
                    {"question": "如果遇到特殊情况，这个流程图需要如何调整？"}
                ]
                
            return generated_code
            
        except Exception as e:
            logger.error(f"生成流程图时出错: {str(e)}")
            return None

    async def generate_diagram_stream(self, query: str, problem_content: str = "", editor_code: str = ""):
        """流式生成Mermaid流程图代码"""
        try:
            logger.info("开始流式生成流程图...")
            logger.info("使用的系统提示词：")
            for line in SYSTEM_PROMPT.split('\n'):
                logger.info(f"  {line}")
            
            # 构建完整的提示
            full_prompt = query
            if problem_content:
                full_prompt = f"有关Mermaid代码生成的可能被使用也可能不被使用的题目：<<<{problem_content}>>>\n\n{query}"
            if editor_code:
                full_prompt = f"{full_prompt}\n\n可能被使用也可能不被使用的当前的编辑器内代码：<<<{editor_code}>>>"
            
            # 记录输入提示
            logger.info(f"流式生成输入提示：\n{full_prompt}")
            
            # 流式生成流程图代码
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": full_prompt}
            ]
            
            logger.info("开始对话流式生成：")
            generated_code = ""
            async for chunk in self._stream_chat(messages):
                if chunk:
                    generated_code += chunk
                    yield chunk
            
            # 预测可能的后续问题
            if self.predictor is None:
                logger.info("初始化问题预测器...")
                self.predictor = init_predictor(DSAPI_KEY)
                
            current_context = {
                'problem_content': problem_content,
                'editor_code': editor_code,
                'query': query,
                'mermaid_code': generated_code
            }
            
            logger.info("开始预测后续问题...")
            next_questions = self.predictor.predict_next_questions(
                current_context=current_context,
                task_response=generated_code
            )
            logger.info(f"预测到 {len(next_questions)} 个问题")
            
            # 如果预测器没有返回问题，使用默认问题
            if not next_questions:
                next_questions = [
                    {"question": "能解释一下这个流程图的每个步骤吗？"},
                    {"question": "这个流程图还可以怎样优化或简化？"},
                    {"question": "如果遇到特殊情况，这个流程图需要如何调整？"}
                ]
            
            # 在流式输出最后，发送预测问题
            yield "\n\n接下来，您可能想问：\n" + "\n".join([
                f"- {q['question']}" for q in next_questions
            ])
            
            logger.info("流式生成完成")
            
        except Exception as e:
            error_msg = f"流式生成流程图时出错: {str(e)}"
            logger.error(error_msg)
            yield error_msg
